{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Recompiling stale cache file /data/scratch/deniz/.julia/compiled/v1.0/Knet/f4vSz.ji for Knet [1902f260-5fb4-5aff-8c31-6271790ab950]\n",
      "└ @ Base loading.jl:1184\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Knet.gpu() = 2\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (c) Deniz Yuret, 2018.\n",
    "# Note that this is an instructional example written in low-level Julia/Knet and it is slow to train.\n",
    "# For a faster and high-level implementation please see `@doc RNN`.\n",
    "# TODO: check the 50% speed regression in julia 1.0.\n",
    "using Pkg; haskey(Pkg.installed(),\"Knet\") || Pkg.add(\"Knet\")\n",
    "using Knet; @show Knet.gpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A one-layer MLP vs. a simple RNN\n",
    "\n",
    "A simple RNN ([Elman 1990](http://onlinelibrary.wiley.com/doi/10.1207/s15516709cog1402_1/pdf)) takes the previous hidden state as an extra input and returns the next hidden state as an extra output."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-rolled.png\" width=\"150\" />\n",
    "([image source](http://colah.github.io/posts/2015-08-Understanding-LSTMs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Comparison of a single hidden layer MLP and the corresponding RNN\n",
    "# param = [W1, b1, W2, b2]; each row of input is one instance of the minibatch\n",
    "\n",
    "function mlp1(param, input)\n",
    "    hidden = tanh.(input * param[1] .+ param[2])   # broadcast tanh elementwise\n",
    "    output = hidden * param[3] .+ param[4]\n",
    "    return output\n",
    "end\n",
    "\n",
    "function rnn1(param, input, hidden)\n",
    "    input2 = hcat(input, hidden)                   # previous hidden state is an extra input\n",
    "    hidden = tanh.(input2 * param[1] .+ param[2])\n",
    "    output = hidden * param[3] .+ param[4]\n",
    "    return (hidden, output)                        # next hidden state is an extra output\n",
    "end;"
  },
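  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick shape check, the sketch below runs `rnn1` over a short sequence using plain Julia arrays instead of Knet arrays. The sizes `B`, `X`, `H`, `Y`, the random parameters, and the helper `unroll` are illustrative assumptions, not part of the notebook's model.\n",
    "\n",
    "```julia\n",
    "# Stand-alone copy of rnn1 on plain arrays (no Knet needed).\n",
    "function rnn1(param, input, hidden)\n",
    "    input2 = hcat(input, hidden)                   # (B, X+H)\n",
    "    hidden = tanh.(input2 * param[1] .+ param[2])  # (B, H)\n",
    "    output = hidden * param[3] .+ param[4]         # (B, Y)\n",
    "    return (hidden, output)\n",
    "end\n",
    "\n",
    "# Hypothetical helper: apply rnn1 to every time step, carrying the hidden state along.\n",
    "function unroll(param, inputs, hidden)\n",
    "    outputs = []\n",
    "    for x in inputs\n",
    "        hidden, output = rnn1(param, x, hidden)\n",
    "        push!(outputs, output)\n",
    "    end\n",
    "    return hidden, outputs\n",
    "end\n",
    "\n",
    "B, X, H, Y = 2, 3, 4, 5               # batch, input, hidden, output sizes (made up)\n",
    "param  = [0.1 * randn(X+H, H), zeros(1, H), 0.1 * randn(H, Y), zeros(1, Y)]\n",
    "inputs = [randn(B, X) for t in 1:6]   # a sequence of 6 time steps\n",
    "\n",
    "hidden, outputs = unroll(param, inputs, zeros(B, H))\n",
    "size(hidden), size(outputs[1]), length(outputs)   # ((2, 4), (2, 5), 6)\n",
    "```\n",
    "\n",
    "The same hidden state is threaded through every step, which is what distinguishes this loop from calling `mlp1` on each input independently."
   ]
  },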
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Backpropagation through time (BPTT)\n",
    "\n",
    "([Werbos, 1988](http://www.sciencedirect.com/science/article/pii/089360808890007X))\n",
    "An RNN unrolled in time is similar to a deep feed-forward network which (i) has as many layers as time steps, (ii) has weights shared between different layers, and (iii) may have multiple inputs and outputs received and produced at individual layers. Ordinary backpropagation applied to this unrolled network can therefore be used to train RNNs; because the weights are shared, the gradient of each weight is the sum of its gradients over all time steps."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-unrolled.png\" width=\"800\" />\n",
    "([image source](http://colah.github.io/posts/2015-08-Understanding-LSTMs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss calculation and training.\n",
    "# loss_function is a placeholder for any per-step loss; it is not defined in this notebook.\n",
    "\n",
    "function rnnloss(param,inputs,hidden,outputs)\n",
    "    # inputs and outputs are sequences of the same length\n",
    "    sumloss = 0\n",
    "    for t in 1:length(inputs)\n",
    "        hidden,output = rnn1(param,inputs[t],hidden)  # rnn1 returns (hidden, output)\n",
    "        sumloss += loss_function(output,outputs[t])\n",
    "    end\n",
    "    return sumloss\n",
    "end\n",
    "\n",
    "rnngrad = grad(rnnloss);\n",
    "\n",
    "# ... train with our usual SGD procedure"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Long Short-Term Memory (LSTM)\n",
    "([Hochreiter and Schmidhuber, 1997](https://doi.org/10.1162/neco.1997.9.8.1735))\n",
    "LSTM is a more sophisticated RNN module that handles long-range dependencies better than a simple RNN by gating what is forgotten from, written to, and read from a separate cell state."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-chain.png\" width=\"800\" />\n",
    "([image source](http://colah.github.io/posts/2015-08-Understanding-LSTMs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "$$\\begin{align}\n",
    "f_t &= \\sigma(W_f\\cdot[h_{t-1},x_t] + b_f) & \\text{forget gate} \\\\\n",
    "i_t &= \\sigma(W_i\\cdot[h_{t-1},x_t] + b_i) & \\text{input gate} \\\\\n",
    "\\tilde{C}_t &= \\tanh(W_C\\cdot[h_{t-1},x_t] + b_C) & \\text{cell candidate} \\\\\n",
    "C_t &= f_t \\ast C_{t-1} + i_t \\ast \\tilde{C}_t & \\text{new cell} \\\\\n",
    "o_t &= \\sigma(W_o\\cdot[h_{t-1},x_t] + b_o) & \\text{output gate} \\\\\n",
    "h_t &= o_t \\ast \\tanh(C_t) & \\text{new output}\\\\\n",
    "\\end{align}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An LSTM implementation in Knet\n",
    "# weight has 4H columns: H-wide blocks for the forget, input and output gates and the cell candidate\n",
    "\n",
    "function lstm(param, state, input)\n",
    "    weight,bias = param\n",
    "    hidden,cell = state\n",
    "    h       = size(hidden,2)\n",
    "    gates   = hcat(input,hidden) * weight .+ bias  # all four pre-activations in one multiplication\n",
    "    forget  = sigm.(gates[:,1:h])\n",
    "    ingate  = sigm.(gates[:,1+h:2h])\n",
    "    outgate = sigm.(gates[:,1+2h:3h])\n",
    "    change  = tanh.(gates[:,1+3h:4h])              # cell candidate\n",
    "    cell    = cell .* forget + ingate .* change\n",
    "    hidden  = outgate .* tanh.(cell)\n",
    "    return (hidden,cell)\n",
    "end;"
   ]
  },
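  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how the equations map onto array shapes, here is a minimal sketch of a single LSTM step on plain Julia arrays. The local `sigm` is a stand-in for Knet's function of the same name, and the sizes `B`, `X`, `H` and the random weights are assumptions chosen for illustration. Note that the code concatenates `[x_t, h_{t-1}]` while the equations write `[h_{t-1}, x_t]`; this only permutes the rows of the weight matrix.\n",
    "\n",
    "```julia\n",
    "sigm(x) = 1 / (1 + exp(-x))   # stand-in for Knet's sigm\n",
    "\n",
    "function lstm(param, state, input)\n",
    "    weight, bias = param\n",
    "    hidden, cell = state\n",
    "    h       = size(hidden, 2)\n",
    "    gates   = hcat(input, hidden) * weight .+ bias   # (B, 4H)\n",
    "    forget  = sigm.(gates[:, 1:h])                   # f_t\n",
    "    ingate  = sigm.(gates[:, 1+h:2h])                # i_t\n",
    "    outgate = sigm.(gates[:, 1+2h:3h])               # o_t\n",
    "    change  = tanh.(gates[:, 1+3h:4h])               # cell candidate\n",
    "    cell    = cell .* forget + ingate .* change      # C_t\n",
    "    hidden  = outgate .* tanh.(cell)                 # h_t\n",
    "    return (hidden, cell)\n",
    "end\n",
    "\n",
    "B, X, H = 2, 3, 4                                    # batch, input, hidden sizes (made up)\n",
    "param = (0.1 * randn(X+H, 4H), zeros(1, 4H))\n",
    "state = (zeros(B, H), zeros(B, H))\n",
    "hidden, cell = lstm(param, state, randn(B, X))\n",
    "size(hidden), size(cell)                             # ((2, 4), (2, 4))\n",
    "```\n",
    "\n",
    "Computing all four gates with one matrix multiplication and then slicing the columns is a common way to keep the step to a single large product."
   ]
  },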
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sequence to sequence model (S2S)\n",
    "([Sutskever et al. 2014](https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf))\n",
    "S2S models learn to map input sequences to output sequences using an encoder and a decoder RNN."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://nzw0301.github.io/images/seq2seq.svg\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# S2S loss function and its gradient\n",
    "\n",
    "function s2s(model, inputs, outputs)\n",
    "    # Encoder: read the input tokens one time step at a time\n",
    "    state = initstate(inputs[1], model[:state0])\n",
    "    for input in inputs\n",
    "        input = onehotrows(input, model[:embed1])\n",
    "        input = input * model[:embed1]\n",
    "        state = lstm(model[:encode], state, input)\n",
    "    end\n",
    "    # Decoder: start from EOS, then feed the gold token of the previous step\n",
    "    EOS = eosmatrix(outputs[1], model[:embed2])\n",
    "    input = EOS * model[:embed2]\n",
    "    sumlogp = 0\n",
    "    for output in outputs\n",
    "        state = lstm(model[:decode], state, input)\n",
    "        ypred = predict(model[:output], state[1])\n",
    "        ygold = onehotrows(output, model[:embed2])\n",
    "        sumlogp += sum(ygold .* logp(ypred,dims=2))  # log probability of the gold tokens\n",
    "        input = ygold * model[:embed2]\n",
    "    end\n",
    "    # The decoder must also predict EOS after the last output token\n",
    "    state = lstm(model[:decode], state, input)\n",
    "    ypred = predict(model[:output], state[1])\n",
    "    sumlogp += sum(EOS .* logp(ypred,dims=2))\n",
    "    return -sumlogp  # negative log likelihood summed over the minibatch\n",
    "end\n",
    "\n",
    "s2sgrad = gradloss(s2s);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"https://docs.google.com/drawings/d/1BR871g8k4jpI-mKeXiJfpY5Jl5cKcognvH7hHSugQds/pub?w=958&h=236\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# S2S model definition\n",
    "\n",
    "function initmodel(H, V; atype=(gpu()>=0 ? KnetArray{Float32} : Array{Float32}))\n",
    "    init(d...)=atype(xavier(d...))\n",
    "    model = Dict{Symbol,Any}()\n",
    "    model[:state0] = [ init(1,H), init(1,H) ]\n",
    "    model[:embed1] = init(V,H)\n",
    "    model[:encode] = [ init(2H,4H), init(1,4H) ]\n",
    "    model[:embed2] = init(V,H)\n",
    "    model[:decode] = [ init(2H,4H), init(1,4H) ]\n",
    "    model[:output] = [ init(H,V), init(1,V) ]\n",
    "    return model\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# S2S helper functions\n",
    "\n",
    "# Linear output layer: param = [weight, bias]\n",
    "function predict(param, input)\n",
    "    input * param[1] .+ param[2]\n",
    "end\n",
    "\n",
    "# Repeat the learned (1,H) initial state for every row of the minibatch\n",
    "function initstate(idx, state0)\n",
    "    h,c = state0\n",
    "    h = h .+ fill!(similar(value(h), length(idx), length(h)), 0)\n",
    "    c = c .+ fill!(similar(value(c), length(idx), length(c)), 0)\n",
    "    return (h,c)\n",
    "end\n",
    "\n",
    "# One-hot matrix with one row per token index in idx\n",
    "function onehotrows(idx, embeddings)\n",
    "    nrows,ncols = length(idx), size(embeddings,1)\n",
    "    z = zeros(Float32,nrows,ncols)\n",
    "    @inbounds for i=1:nrows\n",
    "        z[i,idx[i]] = 1\n",
    "    end\n",
    "    oftype(value(embeddings),z)\n",
    "end\n",
    "\n",
    "# One-hot rows for the EOS token (index 1), cached until the batch shape changes\n",
    "let EOS=nothing; global eosmatrix\n",
    "function eosmatrix(idx, embeddings)\n",
    "    nrows,ncols = length(idx), size(embeddings,1)\n",
    "    if EOS === nothing || size(EOS) != (nrows,ncols)\n",
    "        EOS = zeros(Float32,nrows,ncols)\n",
    "        EOS[:,1] .= 1\n",
    "        EOS = oftype(value(embeddings), EOS)\n",
    "    end\n",
    "    return EOS\n",
    "end\n",
    "end;"
   ]
  },
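  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Multiplying a one-hot row matrix by an embedding matrix, as `s2s` does with the result of `onehotrows`, simply selects rows of the embedding matrix. A minimal check on plain arrays, where the sizes, the values of `embed`, and the indices `idx` are made up for illustration:\n",
    "\n",
    "```julia\n",
    "V, H = 5, 3                                  # hypothetical vocabulary and embedding sizes\n",
    "embed = reshape(collect(1.0:V*H), V, H)      # stand-in embedding matrix, one row per token\n",
    "idx = [2, 5]                                 # token indices for a minibatch of two\n",
    "\n",
    "onehot = zeros(length(idx), V)               # one row per token, one column per vocabulary entry\n",
    "for (i, k) in enumerate(idx)\n",
    "    onehot[i, k] = 1\n",
    "end\n",
    "\n",
    "onehot * embed == embed[idx, :]              # true\n",
    "```\n",
    "\n",
    "The matrix product is used here presumably because it is differentiable with respect to the embedding matrix using only operations the automatic differentiation already supports."
   ]
  },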
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "99171-element Array{Array{Int64,1},1}\n",
      "99171-element Array{SubString{String},1}\n",
      "70-element Array{Char,1}\n",
      "Dict{Char,Int64} with 70 entries\n",
      "Alvaro's\n",
      "Alvin\n",
      "Alvin's\n",
      "Alyce\n",
      "Alyce's\n"
     ]
    }
   ],
   "source": [
    "# Use reversing English words as an example task\n",
    "# This loads them from /usr/share/dict/words and converts each character to an int.\n",
    "\n",
    "function readdata(file=\"/usr/share/dict/words\")\n",
    "    global strings = map(chomp,readlines(file))\n",
    "    global tok2int = Dict{Char,Int}()\n",
    "    global int2tok = Vector{Char}()\n",
    "    push!(int2tok,'\\n'); tok2int['\\n']=1 # We use '\\n'=>1 as the EOS token                                                 \n",
    "    sequences = Vector{Vector{Int}}()\n",
    "    for w in strings\n",
    "        s = Vector{Int}()\n",
    "        for c in collect(w)\n",
    "            if !haskey(tok2int,c)\n",
    "                push!(int2tok,c)\n",
    "                tok2int[c] = length(int2tok)\n",
    "            end\n",
    "            push!(s, tok2int[c])\n",
    "        end\n",
    "        push!(sequences, s)\n",
    "    end\n",
    "    return sequences\n",
    "end\n",
    "\n",
    "sequences = readdata();\n",
    "for x in (sequences, strings, int2tok, tok2int); println(summary(x)); end\n",
    "for x in strings[501:505]; println(x); end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"766-element Array{Any,1}\""
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Minibatch sequences, putting equal-length sequences together.\n",
    "# Each minibatch is a vector of time steps; each time step holds one token per sequence.\n",
    "# Sequences left over in an incomplete batch are dropped.\n",
    "\n",
    "function minibatch(sequences, batchsize)\n",
    "    table = Dict{Int,Vector{Vector{Int}}}()\n",
    "    data = Any[]\n",
    "    for s in sequences\n",
    "        n = length(s)\n",
    "        nsequences = get!(table, n, Vector{Int}[])\n",
    "        push!(nsequences, s)\n",
    "        if length(nsequences) == batchsize\n",
    "            push!(data, [[ nsequences[i][j] for i in 1:batchsize] for j in 1:n ])\n",
    "            empty!(nsequences)\n",
    "        end\n",
    "    end\n",
    "    return data\n",
    "end\n",
    "\n",
    "batchsize, statesize, vocabsize = 128, 128, length(int2tok)\n",
    "data = minibatch(sequences,batchsize)\n",
    "summary(data)"
   ]
  },
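  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layout produced by `minibatch` is easiest to see on a toy input. The sketch below is a stand-alone version of the function applied to five made-up sequences (`toy`) with a batch size of 2; the data are purely illustrative.\n",
    "\n",
    "```julia\n",
    "function minibatch(sequences, batchsize)\n",
    "    table = Dict{Int,Vector{Vector{Int}}}()   # sequence length => sequences waiting for a full batch\n",
    "    data = Any[]\n",
    "    for s in sequences\n",
    "        n = length(s)\n",
    "        nsequences = get!(table, n, Vector{Int}[])\n",
    "        push!(nsequences, s)\n",
    "        if length(nsequences) == batchsize\n",
    "            # transpose: one vector per time step, one entry per sequence\n",
    "            push!(data, [[nsequences[i][j] for i in 1:batchsize] for j in 1:n])\n",
    "            empty!(nsequences)\n",
    "        end\n",
    "    end\n",
    "    return data\n",
    "end\n",
    "\n",
    "toy = [[1,2,3], [4,5], [6,7,8], [9,10], [11,12,13]]\n",
    "batches = minibatch(toy, 2)\n",
    "# batches[1] == [[1, 6], [2, 7], [3, 8]]   (the two length-3 sequences, time-major)\n",
    "# batches[2] == [[4, 9], [5, 10]]          (the two length-2 sequences)\n",
    "# [11,12,13] never completes a batch and is dropped\n",
    "```\n",
    "\n",
    "Grouping by length avoids padding, at the cost of discarding the sequences that do not fill a complete batch."
   ]
  },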
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "train (generic function with 1 method)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training loop\n",
    "\n",
    "function train(model, data, opts)\n",
    "    sumloss = cntloss = 0\n",
    "    for sequence in data\n",
    "        grads,loss = s2sgrad(model, sequence, reverse(sequence))\n",
    "        update!(model, grads, opts)\n",
    "        sumloss += loss\n",
    "        cntloss += (1+length(sequence)) * length(sequence[1])\n",
    "    end\n",
    "    return sumloss/cntloss\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 33.763921 seconds (29.87 M allocations: 1.764 GiB, 3.86% gc time)\n",
      "(1, 2.1116824f0)\n",
      " 24.243649 seconds (17.49 M allocations: 1.127 GiB, 3.00% gc time)\n",
      "(2, 1.1684585f0)\n",
      " 22.244873 seconds (17.52 M allocations: 1.128 GiB, 3.02% gc time)\n",
      "(3, 0.5343968f0)\n",
      " 22.714329 seconds (17.52 M allocations: 1.129 GiB, 2.93% gc time)\n",
      "(4, 0.2762618f0)\n",
      " 22.857687 seconds (17.52 M allocations: 1.128 GiB, 2.88% gc time)\n",
      "(5, 0.16793005f0)\n",
      " 23.432127 seconds (17.49 M allocations: 1.127 GiB, 2.78% gc time)\n",
      "(6, 0.112778656f0)\n",
      " 23.819823 seconds (17.52 M allocations: 1.128 GiB, 2.92% gc time)\n",
      "(7, 0.07800572f0)\n",
      " 25.183425 seconds (17.52 M allocations: 1.128 GiB, 2.85% gc time)\n",
      "(8, 0.061287124f0)\n",
      " 23.019519 seconds (17.52 M allocations: 1.128 GiB, 2.88% gc time)\n",
      "(9, 0.04776922f0)\n",
      " 22.754002 seconds (17.49 M allocations: 1.127 GiB, 2.92% gc time)\n",
      "(10, 0.030740755f0)\n",
      "244.698817 seconds (188.41 M allocations: 11.960 GiB, 3.05% gc time)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"Dict{Symbol,Any} with 6 entries\""
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = opts = nothing; Knet.gc() # clean memory from previous run\n",
    "if !isfile(\"rnnreverse.jld2\")\n",
    "    # Initialize model and optimization parameters\n",
    "    model = initmodel(statesize,vocabsize)\n",
    "    opts = optimizers(model,Adam)\n",
    "    @time for epoch=1:10\n",
    "        @time loss = train(model,data,opts) # 15s (0.6.4) 22s (1.0.0) /epoch\n",
    "        println((epoch,loss))\n",
    "    end\n",
    "    Knet.save(\"rnnreverse.jld2\",\"model\",model)\n",
    "else\n",
    "    model = Knet.load(\"rnnreverse.jld2\",\"model\")\n",
    "end\n",
    "summary(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"nrocirpac\""
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test on some examples.\n",
    "# Greedy decoding: feed the most likely character back in until EOS (index 1) is predicted.\n",
    "\n",
    "function translate(model, str)\n",
    "    state = model[:state0]\n",
    "    for c in collect(str)\n",
    "        input = onehotrows(tok2int[c], model[:embed1])\n",
    "        input = input * model[:embed1]\n",
    "        state = lstm(model[:encode], state, input)\n",
    "    end\n",
    "    input = eosmatrix(1, model[:embed2]) * model[:embed2]\n",
    "    output = Char[]\n",
    "    for step=1:100  # emit at most 100 characters\n",
    "        state = lstm(model[:decode], state, input)\n",
    "        pred = predict(model[:output], state[1])\n",
    "        i = argmax(vec(Array(pred)))\n",
    "        i == 1 && break\n",
    "        push!(output, int2tok[i])\n",
    "        input = onehotrows(i, model[:embed2]) * model[:embed2]\n",
    "    end\n",
    "    String(output)\n",
    "end\n",
    "\n",
    "translate(model,\"capricorn\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.0.0",
   "language": "julia",
   "name": "julia-1.0"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.0.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
